#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test SmolVLA policy with Real-Time Chunking (RTC) enabled during inference."""

import pytest
import torch

from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule  # noqa: E402
from lerobot.policies.factory import make_pre_post_processors  # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig  # noqa: E402
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig  # noqa: F401
from lerobot.utils.random_utils import set_seed  # noqa: E402
from tests.utils import require_cuda, require_package  # noqa: E402


@require_package("transformers")
@require_cuda
def test_smolvla_rtc_initialization():
    """Test SmolVLA policy can initialize RTC processor.

    Builds a SmolVLA config with an enabled ``RTCConfig``, instantiates the
    policy, and checks that an RTC processor carrying that config is attached.
    """
    # Imported inside the test body so importing this module does not load the
    # modeling code (which presumably needs `transformers`) at collection time.
    from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

    set_seed(42)

    config = SmolVLAConfig(max_action_dim=7, chunk_size=50)

    # Add RTC config
    config.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=10,
        max_guidance_weight=5.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        debug=False,
    )

    config.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    # Instantiate policy
    policy = SmolVLAPolicy(config)

    # Verify RTC processor is initialized and carries the enabled config
    assert hasattr(policy, "rtc_processor")
    assert policy.rtc_processor is not None
    assert policy.rtc_processor.rtc_config.enabled is True

    print("✓ SmolVLA RTC initialization: Test passed")


@require_package("transformers")
@require_cuda
def test_smolvla_rtc_initialization_without_rtc_config():
    """Test SmolVLA policy can initialize without RTC config.

    No ``rtc_config`` is assigned, so neither the policy nor its inner model
    should hold an RTC processor, and ``_rtc_enabled()`` must report False.
    """
    # Imported inside the test body so importing this module does not load the
    # modeling code (which presumably needs `transformers`) at collection time.
    from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

    set_seed(42)

    config = SmolVLAConfig(max_action_dim=7, chunk_size=50)

    # Instantiate policy
    policy = SmolVLAPolicy(config)

    # Verify RTC processor is not initialized on either the policy or the model
    assert hasattr(policy, "rtc_processor")
    assert policy.rtc_processor is None
    assert policy.model.rtc_processor is None
    assert policy._rtc_enabled() is False

    print("✓ SmolVLA RTC initialization without RTC config: Test passed")


@require_package("transformers")
@require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_inference_with_prev_chunk():
    """Test SmolVLA policy inference with RTC and previous chunk.

    The same preprocessed batch and the same initial noise are run twice: once
    with RTC guidance from a previous action chunk, and once with RTC disabled.
    The guided prediction is expected to differ from the unguided one.
    """
    # Imported inside the test body so importing this module does not load the
    # modeling code (which presumably needs `transformers`) at collection time.
    from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

    set_seed(42)

    config = SmolVLAConfig(max_action_dim=7, chunk_size=50)

    # Add RTC config
    config.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=10,
        max_guidance_weight=5.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        debug=False,
    )

    config.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    # Create dataset stats (identity normalization: zero mean, unit std)
    dataset_stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    # Instantiate policy and create preprocessor
    policy = SmolVLAPolicy(config)
    policy.eval()
    preprocessor, _ = make_pre_post_processors(
        policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
    )

    device = config.device

    # Create dummy batch
    batch = {
        "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
        "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
        "task": ["Pick up the object"],
    }
    batch = preprocessor(batch)

    # Create previous chunk (leftover actions from a hypothetical earlier inference)
    prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)

    with torch.no_grad():
        # Use same noise for fair comparison; each call gets its own clone so
        # neither run can affect the other's starting point.
        noise = policy.model.sample_noise((1, config.chunk_size, 7), device)

        # Test with RTC and previous chunk
        actions_with_rtc = policy.predict_action_chunk(
            batch,
            noise=noise.clone(),
            prev_chunk_left_over=prev_chunk,
            inference_delay=4,
            execution_horizon=10,
        )

        # Test without RTC for comparison. Restore the flag in `finally` so a
        # failing inference does not leave RTC disabled on the config.
        policy.config.rtc_config.enabled = False
        try:
            actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
        finally:
            policy.config.rtc_config.enabled = True

    # Verify shapes
    assert actions_with_rtc.shape == (1, config.chunk_size, 7)
    assert actions_without_rtc.shape == (1, config.chunk_size, 7)

    # With previous chunk, actions should be different (RTC guidance applied)
    assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)

    print("✓ SmolVLA RTC inference with prev_chunk: Test passed")


@require_package("transformers")
@require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_inference_without_prev_chunk():
    """Test SmolVLA policy inference with RTC but no previous chunk (RTC should have no effect).

    Runs the same preprocessed batch and the same initial noise with RTC
    enabled but ``prev_chunk_left_over=None``, then with RTC disabled. The two
    predictions are expected to match.
    """
    # Imported inside the test body so importing this module does not load the
    # modeling code (which presumably needs `transformers`) at collection time.
    from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

    set_seed(42)

    config = SmolVLAConfig(max_action_dim=7, chunk_size=50)

    # Add RTC config
    config.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=10,
        max_guidance_weight=5.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        debug=False,
    )

    config.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    # Create dataset stats (identity normalization: zero mean, unit std)
    dataset_stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    # Instantiate policy and create preprocessor
    policy = SmolVLAPolicy(config)
    policy.eval()
    preprocessor, _ = make_pre_post_processors(
        policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
    )

    device = config.device

    # Create dummy batch
    batch = {
        "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
        "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
        "task": ["Pick up the object"],
    }
    batch = preprocessor(batch)

    with torch.no_grad():
        # Use same noise for fair comparison; each call gets its own clone.
        noise = policy.model.sample_noise((1, config.chunk_size, 7), device)

        # Test with RTC enabled but no previous chunk
        actions_with_rtc_no_prev = policy.predict_action_chunk(
            batch,
            noise=noise.clone(),
            prev_chunk_left_over=None,
        )

        # Test without RTC. Restore the flag in `finally` so a failing inference
        # does not leave RTC disabled on the config.
        policy.config.rtc_config.enabled = False
        try:
            actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
        finally:
            policy.config.rtc_config.enabled = True

    # Without previous chunk, RTC should have no effect
    assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)

    print("✓ SmolVLA RTC inference without prev_chunk: Test passed")


@require_package("transformers")
@require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_validation_rules():
    """Test SmolVLA policy with RTC guidance from a previous chunk.

    Runs the same preprocessed batch and the same initial noise with RTC
    guidance (previous chunk, explicit ``inference_delay`` and
    ``execution_horizon``) and with RTC disabled. Checks that both outputs have
    the expected shape and that the guided actions differ from the unguided ones.

    NOTE(review): this test was described as checking "all three validation
    rules", but only the RTC-vs-no-RTC difference (plus shapes) is asserted
    here. TODO: add the remaining rule checks once they are specified.
    """
    # Imported inside the test body so importing this module does not load the
    # modeling code (which presumably needs `transformers`) at collection time.
    from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

    set_seed(42)

    config = SmolVLAConfig(max_action_dim=7, chunk_size=50)

    # Add RTC config
    config.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=10,
        max_guidance_weight=5.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        debug=False,
    )

    config.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    # Create dataset stats (identity normalization: zero mean, unit std)
    dataset_stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    # Instantiate policy and create preprocessor
    policy = SmolVLAPolicy(config)
    policy.eval()
    preprocessor, _ = make_pre_post_processors(
        policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
    )

    device = config.device

    # Create dummy batch
    batch = {
        "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
        "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
        "task": ["Pick up the object"],
    }
    batch = preprocessor(batch)

    # Create previous chunk
    prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)

    inference_delay = 4
    execution_horizon = 10

    with torch.no_grad():
        # Use same noise for fair comparison; each call gets its own clone.
        noise = policy.model.sample_noise((1, config.chunk_size, 7), device)

        # Test with RTC
        actions_with_rtc = policy.predict_action_chunk(
            batch,
            noise=noise.clone(),
            prev_chunk_left_over=prev_chunk,
            inference_delay=inference_delay,
            execution_horizon=execution_horizon,
        )

        # Test without RTC. Restore the flag in `finally` so a failing inference
        # does not leave RTC disabled on the config.
        policy.config.rtc_config.enabled = False
        try:
            actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
        finally:
            policy.config.rtc_config.enabled = True

    # Verify shapes (same expectations as the prev_chunk inference test)
    assert actions_with_rtc.shape == (1, config.chunk_size, 7)
    assert actions_without_rtc.shape == (1, config.chunk_size, 7)

    # RTC guidance from the previous chunk should change the predicted actions
    assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
